{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initial Setups"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Google Colab use only)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use Google Colab\n",
    "use_colab = True\n",
    "\n",
    "# Detect Colab by attempting to import the google.colab package\n",
    "# (github.com/googlecolab/colabtools), which is only expected to exist there.\n",
    "# Only ImportError is caught, so unrelated failures (and KeyboardInterrupt)\n",
    "# are no longer silently mistaken for a missing Colab environment.\n",
    "try:\n",
    "    from google.colab import drive\n",
    "    colab_available = True\n",
    "except ImportError:\n",
    "    colab_available = False\n",
    "\n",
    "if use_colab and colab_available:\n",
    "    drive.mount('/content/drive')\n",
    "\n",
    "    # cd to the appropriate working directory under my Google Drive\n",
    "    %cd '/content/drive/My Drive/cs696ds_lexalytics/Prompting Experiments'\n",
    "    \n",
    "    # Install packages specified in requirements\n",
    "    !pip install -r requirements.txt\n",
    "    \n",
    "    # List the directory contents\n",
    "    !ls"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# We will use the following string ID to identify this particular (training) experiment\n",
    "# in directory paths and other settings\n",
    "experiment_id = 'bert_prompt_logit_softmax_atsc_laptops_bert-base-uncased_multiple_prompts'\n",
    "\n",
    "# Random seed\n",
    "random_seed = 696\n",
    "\n",
    "# path to pretrained MLM model folder or the string \"bert-base-uncased\"\n",
    "lm_model_path = 'bert-base-uncased'\n",
    "\n",
    "# Prompts to be added to the end of each review text\n",
    "# Note: pseudo-labels for each prompt should be given in the order of (positive), (negative), (neutral)\n",
    "sentiment_prompts = [\n",
    "    {\"prompt\": \"I felt the {aspect} was [MASK].\", \"labels\": [\"good\", \"bad\", \"ok\"]},\n",
    "    {\"prompt\": \"I [MASK] the {aspect}.\", \"labels\": [\"love\", \"hate\", \"dislike\"]},\n",
    "    {\"prompt\": \"The {aspect} made me feel [MASK].\", \"labels\": [\"good\", \"bad\", \"indifferent\"]},\n",
    "    {\"prompt\": \"The {aspect} is [MASK].\", \"labels\": [\"good\", \"bad\", \"ok\"]}\n",
    "]\n",
    "\n",
    "# Multiple prompt merging behavior\n",
    "prompts_merge_behavior = 'sum_logits'\n",
    "\n",
    "# Perturb the input embeddings of tokens within the prompts\n",
    "prompts_perturb = False\n",
    "\n",
    "# Training settings\n",
    "training_epochs = 20\n",
    "training_batch_size = 16\n",
    "training_learning_rate = 2e-5\n",
    "training_weight_decay = 0.01\n",
    "training_warmup_steps_duration = 0.15\n",
    "training_hard_restart_num_cycles = 3\n",
    "training_max_grad_norm = 1.0\n",
    "training_best_model_criterion = 'train_loss'\n",
    "training_domain = 'laptops' # 'laptops', 'restaurants', 'joint'\n",
    "training_dataset_proportion = 1.0\n",
    "training_dataset_few_shot_size = None\n",
    "\n",
    "training_lm_freeze = False\n",
    "\n",
    "validation_enabled = False\n",
    "validation_dataset_proportion = 0.2\n",
    "validation_batch_size = 16\n",
    "\n",
    "testing_batch_size = 16\n",
    "testing_domain = 'laptops' # 'laptops', 'restaurants', 'joint'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batch size adjustment for multiple prompts: every review is repeated once per\n",
    "# prompt before it is fed to the model, so the number of reviews per batch is\n",
    "# divided by the number of prompts.\n",
    "# The result is clamped to at least 1 so that a training_batch_size smaller than\n",
    "# the number of prompts cannot produce an invalid DataLoader batch size of 0.\n",
    "# Note: re-running this cell divides again; re-run the parameters cell first.\n",
    "training_batch_size = max(1, training_batch_size // len(sentiment_prompts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Experiment ID:\", experiment_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Package imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "import copy\n",
    "import inspect\n",
    "import json\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import transformers\n",
    "import datasets\n",
    "import sklearn.metrics\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sn\n",
    "import tqdm\n",
    "import tqdm.notebook\n",
    "\n",
    "# Make the parent directory importable before importing the project-local utils module\n",
    "# (presumably utils lives there -- confirm against the repository layout)\n",
    "current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))\n",
    "parent_dir = os.path.dirname(current_dir)\n",
    "sys.path.append(parent_dir)\n",
    "\n",
    "import utils\n",
    "\n",
    "# Random seed settings\n",
    "random.seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "# cuBLAS reproducibility\n",
    "# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = \":4096:8\"\n",
    "\n",
    "# torch.set_deterministic() was deprecated in PyTorch 1.8 in favor of\n",
    "# torch.use_deterministic_algorithms() and is missing from newer releases.\n",
    "# Prefer the new API when it exists and fall back to the old one otherwise.\n",
    "if hasattr(torch, 'use_deterministic_algorithms'):\n",
    "    torch.use_deterministic_algorithms(True)\n",
    "else:\n",
    "    torch.set_deterministic(True)\n",
    "torch.manual_seed(random_seed)\n",
    "\n",
    "# Print version information\n",
    "print(\"Python version: \" + sys.version)\n",
    "print(\"NumPy version: \" + np.__version__)\n",
    "print(\"PyTorch version: \" + torch.__version__)\n",
    "print(\"Transformers version: \" + transformers.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PyTorch GPU settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device and derive the settings that depend on it\n",
    "cuda_is_usable = torch.cuda.is_available()\n",
    "\n",
    "if not cuda_is_usable:\n",
    "    # CPU fallback: one compute device and no page-locked host memory\n",
    "    torch_device = torch.device('cpu')\n",
    "    use_pin_memory = False\n",
    "    training_device_count = 1\n",
    "\n",
    "else:\n",
    "    torch_device = torch.device('cuda')\n",
    "\n",
    "    # Deterministic cuDNN kernels for reproducible output\n",
    "    # Note: https://pytorch.org/docs/stable/notes/randomness.html\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "    # 'benchmark' mode stays disabled so that running times are measured more fairly\n",
    "    # Note: https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    # Faster Host to GPU copies with page-locked memory\n",
    "    use_pin_memory = True\n",
    "\n",
    "    # Number of compute devices to be used for training\n",
    "    training_device_count = torch.cuda.device_count()\n",
    "\n",
    "    # CUDA libraries version information, printed as (caption, value) pairs\n",
    "    cuda_report = [\n",
    "        (\"CUDA Version: \", torch.version.cuda),\n",
    "        (\"cuDNN Version: \", torch.backends.cudnn.version()),\n",
    "        (\"CUDA Device Name: \", torch.cuda.get_device_name()),\n",
    "        (\"CUDA Capabilities: \", torch.cuda.get_device_capability()),\n",
    "        (\"Number of CUDA devices: \", training_device_count),\n",
    "    ]\n",
    "\n",
    "    for caption, value in cuda_report:\n",
    "        print(caption + str(value))\n",
    "\n",
    "print()\n",
    "print(\"PyTorch device selected:\", torch_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare Datasets for Prompt-based Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the SemEval dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load semeval for both domains; both go through the same local loading script\n",
    "semeval_script = os.path.abspath('../dataset_scripts/semeval2014_task4/semeval2014_task4.py')\n",
    "\n",
    "laptops_files = {\n",
    "    'test': '../dataset_files/semeval_2014/Laptops_Test_Gold.xml',\n",
    "    'train': '../dataset_files/semeval_2014/Laptop_Train_v2.xml',\n",
    "}\n",
    "\n",
    "restaurants_files = {\n",
    "    'test': '../dataset_files/semeval_2014/Restaurants_Test_Gold.xml',\n",
    "    'train': '../dataset_files/semeval_2014/Restaurants_Train_v2.xml',\n",
    "}\n",
    "\n",
    "laptops_dataset = datasets.load_dataset(\n",
    "    semeval_script, data_files=laptops_files, cache_dir='../dataset_cache')\n",
    "\n",
    "restaurants_dataset = datasets.load_dataset(\n",
    "    semeval_script, data_files=restaurants_files, cache_dir='../dataset_cache')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset chosen for training/testing\n",
    "print(\"Training domain:\", training_domain)\n",
    "print(\"Testing domain:\", testing_domain)\n",
    "\n",
    "\n",
    "def select_domain_split(domain, split):\n",
    "    \"\"\"Return the given split ('train' or 'test') of the requested domain.\n",
    "\n",
    "    domain: 'laptops', 'restaurants' or 'joint'. 'joint' concatenates the laptops\n",
    "        and restaurants splits with datasets.concatenate_datasets(), because\n",
    "        datasets.Dataset objects cannot be combined with the + operator.\n",
    "    split: key of the split inside the loaded dataset dictionaries.\n",
    "\n",
    "    Raises ValueError for an unknown domain, instead of leaving the result\n",
    "    undefined and failing later with a NameError.\n",
    "    \"\"\"\n",
    "    if domain == 'laptops':\n",
    "        return laptops_dataset[split]\n",
    "    if domain == 'restaurants':\n",
    "        return restaurants_dataset[split]\n",
    "    if domain == 'joint':\n",
    "        return datasets.concatenate_datasets(\n",
    "            [laptops_dataset[split], restaurants_dataset[split]])\n",
    "\n",
    "    raise ValueError(\n",
    "        \"Unknown domain {!r}; expected 'laptops', 'restaurants' or 'joint'\".format(domain))\n",
    "\n",
    "\n",
    "train_set = select_domain_split(training_domain, 'train')\n",
    "test_set = select_domain_split(testing_domain, 'test')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train-validation split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset sizes after applying training_dataset_proportion and, if enabled,\n",
    "# carving the validation split out of the remaining examples\n",
    "proportioned_dataset_size = int(len(train_set) * training_dataset_proportion)\n",
    "\n",
    "if validation_enabled:\n",
    "    new_train_dataset_size = int(proportioned_dataset_size * (1 - validation_dataset_proportion))\n",
    "    new_valid_dataset_size = proportioned_dataset_size - new_train_dataset_size\n",
    "\n",
    "    print(\"Training dataset after split:\", new_train_dataset_size)\n",
    "    # Report the validation size here (the training size used to be printed twice)\n",
    "    print(\"Validation dataset after split:\", new_valid_dataset_size)\n",
    "else:\n",
    "    new_train_dataset_size = proportioned_dataset_size\n",
    "\n",
    "    print(\"Training dataset size:\", new_train_dataset_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle into a separate variable so that re-running this cell is idempotent\n",
    "# (reassigning train_set would reshuffle the already-shuffled data on every re-run\n",
    "# and silently change which examples end up in each split)\n",
    "shuffled_train_set = train_set.shuffle(seed=random_seed)\n",
    "\n",
    "# The first new_train_dataset_size shuffled examples form the training split\n",
    "new_train_set = shuffled_train_set.select(indices=np.arange(new_train_dataset_size))\n",
    "\n",
    "if validation_enabled:\n",
    "    # The validation split is the slice that immediately follows the training split\n",
    "    new_valid_set = shuffled_train_set.select(\n",
    "        indices=np.arange(\n",
    "            new_train_dataset_size,\n",
    "            new_train_dataset_size + new_valid_dataset_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For few shot experiments, keep only the leading examples of new_train_set\n",
    "few_shot_requested = training_dataset_few_shot_size is not None\n",
    "\n",
    "if few_shot_requested:\n",
    "    few_shot_indices = np.arange(training_dataset_few_shot_size)\n",
    "    new_train_set = new_train_set.select(indices=few_shot_indices)\n",
    "\n",
    "    print(\"Few-shot training dataset size:\", len(new_train_set))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first training example as a quick sanity check of the dataset contents\n",
    "first_example = new_train_set[0]\n",
    "print(first_example)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot ATSC with Prompts + MLM Output Head"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the pretrained LM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The tokenizer is always the stock bert-base-uncased one, independent of lm_model_path\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\n",
    "    'bert-base-uncased',\n",
    "    cache_dir='../bert_base_cache')\n",
    "\n",
    "# Load pretrained language model (a hub model name or a local model folder)\n",
    "lm = transformers.AutoModelForMaskedLM.from_pretrained(lm_model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the MLM main layer and leave the MLM head trainable\n",
    "# Note: Since input and output word embeddings are tied,\n",
    "# the output word embedding layer in lm.cls will remain untrainable\n",
    "# https://github.com/huggingface/transformers/blob/master/src/transformers/configuration_utils.py#L170\n",
    "if training_lm_freeze:\n",
    "    # Order matters: the head is marked trainable first, the encoder is frozen second\n",
    "    for module, trainable in ((lm.cls, True), (lm.bert, False)):\n",
    "        for param in module.parameters():\n",
    "            param.requires_grad = trainable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a new model with MLM output head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the pseudo-label words of every prompt: one list of vocabulary ids per prompt,\n",
    "# in the label order given in sentiment_prompts\n",
    "sentiment_word_ids = [\n",
    "    [tokenizer.convert_tokens_to_ids(label_word) for label_word in prompt_spec['labels']]\n",
    "    for prompt_spec in sentiment_prompts\n",
    "]\n",
    "\n",
    "print(sentiment_word_ids)\n",
    "\n",
    "# Wrap the LM in the multi-prompt classification head and move it to the selected device\n",
    "classifier_model = utils.MultiPromptLogitSentimentClassificationHead(\n",
    "    lm=lm,\n",
    "    num_class=3,\n",
    "    num_prompts=len(sentiment_prompts),\n",
    "    pseudo_label_words=sentiment_word_ids,\n",
    "    target_token_id=tokenizer.mask_token_id,\n",
    "    merge_behavior=prompts_merge_behavior,\n",
    "    perturb_prompts=prompts_perturb\n",
    ").to(device=torch_device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training batches are reshuffled by the loader; validation batches keep dataset order\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset=new_train_set,\n",
    "    batch_size=training_batch_size,\n",
    "    shuffle=True,\n",
    "    pin_memory=use_pin_memory)\n",
    "\n",
    "if validation_enabled:\n",
    "    validation_dataloader = torch.utils.data.DataLoader(\n",
    "        dataset=new_valid_set,\n",
    "        batch_size=validation_batch_size,\n",
    "        pin_memory=use_pin_memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One optimizer step is taken per batch, in every epoch\n",
    "steps_per_epoch = len(train_dataloader)\n",
    "training_total_steps = steps_per_epoch * training_epochs\n",
    "\n",
    "print(\"There will be %d training steps.\" % training_total_steps)\n",
    "\n",
    "# Warm up during the first training_warmup_steps_duration fraction of all steps\n",
    "# (e.g. 0.15 means the first 15% of the steps)\n",
    "training_warmup_steps = int(training_total_steps * training_warmup_steps_duration)\n",
    "\n",
    "print(\"Warmup steps:\", training_warmup_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_function = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Parameters whose name contains one of these markers are excluded from weight decay\n",
    "no_decay = ['bias', 'LayerNorm.weight']\n",
    "\n",
    "decayed_params = []\n",
    "undecayed_params = []\n",
    "\n",
    "for param_name, param in classifier_model.named_parameters():\n",
    "    if any(marker in param_name for marker in no_decay):\n",
    "        undecayed_params.append(param)\n",
    "    else:\n",
    "        decayed_params.append(param)\n",
    "\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': decayed_params, 'weight_decay': training_weight_decay},\n",
    "    {'params': undecayed_params, 'weight_decay': 0.0},\n",
    "]\n",
    "\n",
    "optimizer = transformers.AdamW(\n",
    "    optimizer_grouped_parameters,\n",
    "    lr=training_learning_rate,\n",
    "    weight_decay=training_weight_decay)\n",
    "\n",
    "# Cosine schedule with hard restarts, preceded by a warmup phase\n",
    "scheduler = transformers.get_cosine_with_hard_restarts_schedule_with_warmup(\n",
    "    optimizer,\n",
    "    num_warmup_steps=training_warmup_steps,\n",
    "    num_training_steps=training_total_steps,\n",
    "    num_cycles=training_hard_restart_num_cycles)\n",
    "\n",
    "# The directory to save the best version of the head.\n",
    "# Anything left there by a previous run of the same experiment_id is wiped first.\n",
    "trained_model_directory = os.path.join('..', 'trained_models', experiment_id)\n",
    "\n",
    "shutil.rmtree(trained_model_directory, ignore_errors=True)\n",
    "os.makedirs(trained_model_directory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(predictions, labels):\n",
    "    \"\"\"Compute accuracy and macro-averaged precision, recall and F1.\n",
    "\n",
    "    predictions: per-class scores; the argmax over the last axis is the predicted class.\n",
    "    labels: ground-truth class indices (the classes are fixed to 0, 1 and 2).\n",
    "\n",
    "    Returns a dict with the keys 'accuracy', 'f1', 'precision' and 'recall'.\n",
    "    \"\"\"\n",
    "    predicted_classes = predictions.argmax(-1)\n",
    "\n",
    "    # Returns (precision, recall, f1, support); support is not needed here\n",
    "    macro_scores = sklearn.metrics.precision_recall_fscore_support(\n",
    "        y_true=labels, y_pred=predicted_classes, labels=[0,1,2], average='macro')\n",
    "\n",
    "    return {\n",
    "        'accuracy': sklearn.metrics.accuracy_score(labels, predicted_classes),\n",
    "        'f1': macro_scores[2],\n",
    "        'precision': macro_scores[0],\n",
    "        'recall': macro_scores[1]\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fail fast on a model selection criterion that this loop cannot evaluate,\n",
    "# rather than hitting a NameError only after a full epoch of training\n",
    "if training_best_model_criterion not in ('train_loss', 'valid_loss'):\n",
    "    raise ValueError(\n",
    "        \"Unsupported training_best_model_criterion: {!r}\".format(training_best_model_criterion))\n",
    "\n",
    "if training_best_model_criterion == 'valid_loss' and not validation_enabled:\n",
    "    raise ValueError(\n",
    "        \"training_best_model_criterion 'valid_loss' requires validation_enabled = True\")\n",
    "\n",
    "best_epoch = -1\n",
    "\n",
    "# Both supported criteria are losses, so a lower score is better\n",
    "best_score = float('inf')\n",
    "\n",
    "for epoch in tqdm.notebook.tqdm(range(int(training_epochs))):\n",
    "\n",
    "    print(\"Training epoch %d\" % epoch)\n",
    "    print()\n",
    "\n",
    "    classifier_model.train()\n",
    "\n",
    "    # Running totals for the example-weighted mean training loss of this epoch\n",
    "    epoch_loss_sum = 0.0\n",
    "    epoch_example_count = 0\n",
    "\n",
    "    for batch in tqdm.notebook.tqdm(train_dataloader):\n",
    "\n",
    "        # Pair every review of the batch with every prompt (prompt-major order)\n",
    "        reviews_repeated = []\n",
    "        prompts_populated = []\n",
    "\n",
    "        for prompt in sentiment_prompts:\n",
    "            reviews_repeated = reviews_repeated + batch[\"text\"]\n",
    "\n",
    "            for aspect in batch[\"aspect\"]:\n",
    "                prompts_populated.append(prompt['prompt'].format(aspect=aspect))\n",
    "\n",
    "        batch_encoded = tokenizer(\n",
    "            reviews_repeated, prompts_populated,\n",
    "            padding='max_length', truncation='only_first', max_length=256,\n",
    "            return_tensors='pt')\n",
    "        \n",
    "        batch_encoded = batch_encoded.to(torch_device)\n",
    "\n",
    "        batch_label = batch[\"sentiment\"]\n",
    "        batch_label = batch_label.to(torch_device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        batch_output = classifier_model(batch_encoded)\n",
    "\n",
    "        loss = loss_function(batch_output, batch_label)\n",
    "\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), training_max_grad_norm)\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        # The loss is a per-batch mean (assuming the default CrossEntropyLoss reduction),\n",
    "        # so it is weighted by the batch size before being accumulated\n",
    "        batch_example_count = batch_label.size(0)\n",
    "        epoch_loss_sum += loss.item() * batch_example_count\n",
    "        epoch_example_count += batch_example_count\n",
    "\n",
    "    # Mean loss over the whole epoch rather than the loss of the last mini-batch only;\n",
    "    # used both for reporting and for the 'train_loss' model selection criterion\n",
    "    epoch_train_loss = epoch_loss_sum / max(epoch_example_count, 1)\n",
    "\n",
    "    print(\"Epoch {}, Training Loss: {}\".format(epoch, epoch_train_loss))\n",
    "    print()\n",
    "\n",
    "    # Validate the model using val dataset\n",
    "    if validation_enabled:\n",
    "        with torch.no_grad():\n",
    "            classifier_model.eval()\n",
    "\n",
    "            print(\"Validation epoch %d\" % epoch)\n",
    "            print()\n",
    "\n",
    "            predictions_val = torch.Tensor()\n",
    "            labels_val = torch.Tensor()\n",
    "\n",
    "            for batch_val in tqdm.notebook.tqdm(validation_dataloader):\n",
    "\n",
    "                reviews_repeated = []\n",
    "                prompts_populated = []\n",
    "\n",
    "                for prompt in sentiment_prompts:\n",
    "                    reviews_repeated = reviews_repeated + batch_val[\"text\"]\n",
    "\n",
    "                    for aspect in batch_val[\"aspect\"]:\n",
    "                        prompts_populated.append(prompt['prompt'].format(aspect=aspect))\n",
    "\n",
    "                batch_val_encoded = tokenizer(\n",
    "                    reviews_repeated, prompts_populated,\n",
    "                    padding='max_length', truncation='only_first', max_length=256,\n",
    "                    return_tensors='pt')\n",
    "\n",
    "                batch_val_encoded.to(torch_device)\n",
    "\n",
    "                batch_val_label = batch_val[\"sentiment\"]\n",
    "\n",
    "                batch_val_output = classifier_model(batch_val_encoded)\n",
    "\n",
    "                # Outputs are gathered on the CPU, where the labels already are\n",
    "                batch_val_output = batch_val_output.to('cpu')\n",
    "\n",
    "                predictions_val = torch.cat([predictions_val, batch_val_output])\n",
    "                labels_val = torch.cat([labels_val, batch_val_label])\n",
    "\n",
    "            # Compute metrics\n",
    "            validation_loss = torch.nn.functional.cross_entropy(predictions_val, labels_val.long())\n",
    "            validation_metrics = compute_metrics(predictions_val, labels_val)\n",
    "\n",
    "            print(\"Validation Loss: {}, Validation Metrics: {}\".format(validation_loss.item(), validation_metrics))\n",
    "            print()\n",
    "\n",
    "    # Score of this epoch under the selected criterion (validated before the loop)\n",
    "    if training_best_model_criterion == 'train_loss':\n",
    "        epoch_score = epoch_train_loss\n",
    "    else:\n",
    "        epoch_score = validation_loss.item()\n",
    "\n",
    "    better_model_found = epoch_score < best_score\n",
    "\n",
    "    # Save the current epoch's model if its score improves on the best known so far,\n",
    "    # so that only the checkpoint of the best epoch is kept on disk\n",
    "    if better_model_found:\n",
    "        if best_epoch != -1:\n",
    "            try:\n",
    "                os.remove(os.path.join(trained_model_directory, 'epoch_{}.pt'.format(best_epoch)))\n",
    "            except OSError:\n",
    "                # Best effort: a stale checkpoint that cannot be removed is not fatal\n",
    "                pass\n",
    "\n",
    "        best_score = epoch_score\n",
    "        best_epoch = epoch\n",
    "\n",
    "        torch.save(\n",
    "            classifier_model.state_dict(),\n",
    "            os.path.join(trained_model_directory, 'epoch_{}.pt'.format(epoch)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation with in-domain test set\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test batches are served in dataset order (no shuffling)\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset=test_set,\n",
    "    batch_size=testing_batch_size,\n",
    "    pin_memory=use_pin_memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the checkpoint of the best epoch and evaluate it on the test set\n",
    "with torch.no_grad():\n",
    "    print('Loading epoch {}'.format(best_epoch))\n",
    "\n",
    "    best_checkpoint_path = os.path.join(\n",
    "        trained_model_directory, 'epoch_{}.pt'.format(best_epoch))\n",
    "\n",
    "    classifier_model.load_state_dict(\n",
    "        torch.load(best_checkpoint_path, map_location=torch_device))\n",
    "\n",
    "    classifier_model.eval()\n",
    "\n",
    "    # Model outputs and labels of all test batches are accumulated on the CPU\n",
    "    predictions_test = torch.Tensor()\n",
    "    labels_test = torch.Tensor()\n",
    "\n",
    "    for batch_test in tqdm.notebook.tqdm(test_dataloader):\n",
    "\n",
    "        # Pair every review with every prompt: the reviews are tiled once per prompt and\n",
    "        # the prompts are filled with the aspects in the same (prompt-major) order\n",
    "        reviews_repeated = batch_test[\"text\"] * len(sentiment_prompts)\n",
    "\n",
    "        prompts_populated = [\n",
    "            prompt['prompt'].format(aspect=aspect)\n",
    "            for prompt in sentiment_prompts\n",
    "            for aspect in batch_test[\"aspect\"]]\n",
    "\n",
    "        batch_test_encoded = tokenizer(\n",
    "            reviews_repeated, prompts_populated,\n",
    "            padding='max_length', truncation='only_first', max_length=256,\n",
    "            return_tensors='pt').to(torch_device)\n",
    "\n",
    "        batch_test_output = classifier_model(batch_test_encoded).to('cpu')\n",
    "\n",
    "        predictions_test = torch.cat([predictions_test, batch_test_output])\n",
    "        labels_test = torch.cat([labels_test, batch_test[\"sentiment\"]])\n",
    "\n",
    "    # Compute metrics\n",
    "    test_metrics = compute_metrics(predictions_test, labels_test)\n",
    "\n",
    "    print(test_metrics)\n",
    "\n",
    "    # Save test_metrics into a file for later processing\n",
    "    test_metrics_path = os.path.join(trained_model_directory, 'test_metrics.json')\n",
    "\n",
    "    with open(test_metrics_path, 'w') as test_metrics_file:\n",
    "        json.dump(test_metrics, test_metrics_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrix of true vs. predicted classes on the test set.\n",
    "# labels=[0, 1, 2] pins the matrix to 3x3 even when some class never occurs in the\n",
    "# test labels or predictions; without it the 3-name DataFrame below would fail on a\n",
    "# smaller matrix. Class indices follow the (positive), (negative), (neutral) order.\n",
    "class_names = [\"positive\", \"negative\", \"neutral\"]\n",
    "\n",
    "cm = sklearn.metrics.confusion_matrix(\n",
    "    labels_test.detach().numpy().astype(int),\n",
    "    predictions_test.detach().numpy().argmax(-1),\n",
    "    labels=[0, 1, 2])\n",
    "\n",
    "df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)\n",
    "\n",
    "plt.figure(figsize=(10, 7))\n",
    "\n",
    "ax = sn.heatmap(df_cm, annot=True)\n",
    "\n",
    "ax.set(xlabel='Predicted Label', ylabel='True Label')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "celltoolbar": "Tags",
  "colab": {
   "collapsed_sections": [],
   "name": "prompt_logit_softmax_atsc_single_prompt_i_felt_bert_amazon_electronics.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "04ab9f4d202d4510bf1d5b72274cff8d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_ddbb891369ec437f81e00a47aefc32dd",
       "IPY_MODEL_2b4f63617920492a9b0f421999b88ec9"
      ],
      "layout": "IPY_MODEL_f8506ddaddf0470f85203cfa01349f3d"
     }
    },
    "0b1cc3f1c83d42bdaee67245a070058b": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "0be1f709b9bf4c778e8b576cc1b391b4": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "1794cd2644d243a48627f82ed4eea1ad": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2b4f63617920492a9b0f421999b88ec9": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0be1f709b9bf4c778e8b576cc1b391b4",
      "placeholder": "​",
      "style": "IPY_MODEL_a40aa7668bcd4f8796a5abca46c13590",
      "value": " 23/23 [00:13&lt;00:00,  1.71it/s]"
     }
    },
    "2d966bc74d8d4c78b5dfe18d91dea381": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "2edf4e3164904407b98fe217a9632abe": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "31a968113d2d4b41949f7c08d9f6ca0c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "31c0d2821f2449b4b87ef333da933918": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_accaa4c642374d9a90d94cf00dde98f4",
      "max": 20,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f7a442ee15784e7fbdf8f82f059bb33a",
      "value": 20
     }
    },
    "3504b977db434bf2a78a8ff8909a6361": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "3ff0a0b85d234c05b6e6827fd4bac876": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_0b1cc3f1c83d42bdaee67245a070058b",
      "placeholder": "​",
      "style": "IPY_MODEL_c12e21d86926469793cdaea2fa3db031",
      "value": " 91/91 [01:15&lt;00:00,  1.21it/s]"
     }
    },
    "40442308fe8f403facfa3558c9e48550": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "4e6e9b6694f542cfa2e8198c68999eb9": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "50fe91c1b38f4f16b7122e8371a911d3": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "5574d6a08ac94dceb4e88a924b3aabfb": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_1794cd2644d243a48627f82ed4eea1ad",
      "placeholder": "​",
      "style": "IPY_MODEL_31a968113d2d4b41949f7c08d9f6ca0c",
      "value": " 23/23 [00:14&lt;00:00,  1.57it/s]"
     }
    },
    "5c13d354d53c42328e4e6f22a73997e3": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "6108e8e60fa34050b744ba4c6ab26120": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "62dab249f9d9493595e26146efaf1ce1": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "64370dad11c74804b13494a6aa2534db": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "66fada69a93547d39175b0cbbb42b8ea": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_40442308fe8f403facfa3558c9e48550",
      "max": 91,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_62dab249f9d9493595e26146efaf1ce1",
      "value": 91
     }
    },
    "6de1bbe40a634c90a957ca0e13e2b0bf": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_3504b977db434bf2a78a8ff8909a6361",
      "max": 2,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_8704563f224149c6b46300b53681c280",
      "value": 2
     }
    },
    "71a600990702485da87ff765724a8c1a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_66fada69a93547d39175b0cbbb42b8ea",
       "IPY_MODEL_73d62910981940e998269dafffea565e"
      ],
      "layout": "IPY_MODEL_2d966bc74d8d4c78b5dfe18d91dea381"
     }
    },
    "73d62910981940e998269dafffea565e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_4e6e9b6694f542cfa2e8198c68999eb9",
      "placeholder": "​",
      "style": "IPY_MODEL_b2cef76b560d4c1c91224a47ad5d728a",
      "value": " 91/91 [01:05&lt;00:00,  1.40it/s]"
     }
    },
    "7ee4bc104d304bda936c6e0be5df219c": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_94b1b0cbb90c4568b00b6cb65b36319e",
       "IPY_MODEL_3ff0a0b85d234c05b6e6827fd4bac876"
      ],
      "layout": "IPY_MODEL_94874c5238dd4669af756c80b9fb4036"
     }
    },
    "8704563f224149c6b46300b53681c280": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "88a2df9016a743a48b7d99cf63a328b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_31c0d2821f2449b4b87ef333da933918",
       "IPY_MODEL_962e93b7344b4e08aa0eeb6c93830d80"
      ],
      "layout": "IPY_MODEL_64370dad11c74804b13494a6aa2534db"
     }
    },
    "8a9a2c3fe8704d3c8adb4fb33bcc2759": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_91f328571e3b432ea4c536b7560ec61a",
      "max": 23,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_a1c199f42af24afd94557c3e4ebe099d",
      "value": 23
     }
    },
    "91f328571e3b432ea4c536b7560ec61a": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "94874c5238dd4669af756c80b9fb4036": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "94b1b0cbb90c4568b00b6cb65b36319e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_5c13d354d53c42328e4e6f22a73997e3",
      "max": 91,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_f266547965be4be3aa18894dd09e9433",
      "value": 91
     }
    },
    "962e93b7344b4e08aa0eeb6c93830d80": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_a13f78e5451a440ba61a2427fef9836d",
      "placeholder": "​",
      "style": "IPY_MODEL_50fe91c1b38f4f16b7122e8371a911d3",
      "value": " 20/20 [00:12&lt;00:00,  1.57it/s]"
     }
    },
    "a13f78e5451a440ba61a2427fef9836d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "a1c199f42af24afd94557c3e4ebe099d": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "a40aa7668bcd4f8796a5abca46c13590": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "accaa4c642374d9a90d94cf00dde98f4": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "afd35f4c4ff445fd9b3ccb21745cd85b": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "b2cef76b560d4c1c91224a47ad5d728a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "b704799e7f9f4f319931aaf722613b99": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "c12e21d86926469793cdaea2fa3db031": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "c61e42199d9f41f78e0fd8fdc4c6b606": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "cffb537f6f9e410a8fd1643f4c0ff97e": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_2edf4e3164904407b98fe217a9632abe",
      "placeholder": "​",
      "style": "IPY_MODEL_c61e42199d9f41f78e0fd8fdc4c6b606",
      "value": " 2/2 [02:34&lt;00:00, 77.34s/it]"
     }
    },
    "dce0fd40bcca4d0b82fbbeb676316496": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_8a9a2c3fe8704d3c8adb4fb33bcc2759",
       "IPY_MODEL_5574d6a08ac94dceb4e88a924b3aabfb"
      ],
      "layout": "IPY_MODEL_b704799e7f9f4f319931aaf722613b99"
     }
    },
    "ddbb891369ec437f81e00a47aefc32dd": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "100%",
      "description_tooltip": null,
      "layout": "IPY_MODEL_6108e8e60fa34050b744ba4c6ab26120",
      "max": 23,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_afd35f4c4ff445fd9b3ccb21745cd85b",
      "value": 23
     }
    },
    "ddd7b7342e054f9a92e9adeed966e525": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "f266547965be4be3aa18894dd09e9433": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "f7a442ee15784e7fbdf8f82f059bb33a": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": "initial"
     }
    },
    "f8506ddaddf0470f85203cfa01349f3d": {
     "model_module": "@jupyter-widgets/base",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fe5aa2051ed64b52bcbde70278e8a763": {
     "model_module": "@jupyter-widgets/controls",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_6de1bbe40a634c90a957ca0e13e2b0bf",
       "IPY_MODEL_cffb537f6f9e410a8fd1643f4c0ff97e"
      ],
      "layout": "IPY_MODEL_ddd7b7342e054f9a92e9adeed966e525"
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
